import numpy as np
import pandas as pd
import optuna
from typing import List, Dict, Any, Optional, Tuple

from hypersense.importance.base_analyzer import BaseImportanceAnalyzer
from optuna.importance import PedAnovaImportanceEvaluator


class PedAnovaAnalyzer(BaseImportanceAnalyzer):
    """
    PED-ANOVA based hyperparameter importance analyzer (using optuna).

    The observed ``(config, score)`` pairs are replayed into an in-memory
    Optuna study (direction ``"maximize"``, i.e. higher scores are treated as
    better) and handed to ``PedAnovaImportanceEvaluator``.

    Note:
        Parameters that are missing from some configurations are registered
        only on the trials that actually contain them. Optuna's importance
        evaluators are expected to operate on the search space shared by all
        trials, so such parameters may be absent from the result.
    """

    def __init__(self, baseline_quantile: float = 0.1, evaluate_on_local: bool = True):
        """
        Create the analyzer and its underlying Optuna evaluator.

        Args:
            baseline_quantile: Baseline quantile for computing importance (e.g., 0.1 = top 10%).
            evaluate_on_local: Whether to compute local (default) or global importance.
        """
        super().__init__()
        self.baseline_quantile = baseline_quantile
        self.evaluate_on_local = evaluate_on_local
        self.evaluator = PedAnovaImportanceEvaluator(
            baseline_quantile=self.baseline_quantile,
            evaluate_on_local=self.evaluate_on_local,
        )

    def fit(self, configs: List[Dict[str, Any]], scores: List[float]) -> None:
        """
        Fit the PED-ANOVA analyzer on sampled configurations and their corresponding scores.

        Args:
            configs: One dict per evaluated configuration, mapping parameter
                name to its value (bool, int, float or str).
            scores: Objective value of each configuration, in the same order.

        Raises:
            ValueError: If ``configs`` and ``scores`` differ in length or a
                parameter value has an unsupported type.
        """
        self._study = self._build_study(configs, scores)
        importance = self.evaluator.evaluate(self._study)

        self.feature_importances_ = dict(importance)
        # PED-ANOVA does not provide interaction importance.
        self.interaction_importances_ = {}

    def explain(self, normalize: bool = True) -> Dict[str, float]:
        """
        Return the per-parameter importance scores computed by ``fit``.

        Args:
            normalize: If True, rescale the scores so that they sum to 1
                (all zeros when the raw scores sum to 0).

        Returns:
            A new dict mapping parameter name to importance.

        Raises:
            ValueError: If ``fit`` has not been called yet.
        """
        importances = self._require_importances()

        if not normalize:
            # Hand out a copy so callers cannot mutate the analyzer's state.
            return dict(importances)

        total = sum(importances.values())
        if total == 0:
            return {name: 0.0 for name in importances}

        return {name: value / total for name, value in importances.items()}

    def explain_interactions(self) -> Dict[Tuple[str, str], float]:
        """Return an empty dict: PED-ANOVA has no interaction importances."""
        return {}

    def rank(self) -> List[str]:
        """
        Return parameter names ordered from most to least important.

        Raises:
            ValueError: If ``fit`` has not been called yet.
        """
        importances = self._require_importances()
        return sorted(importances, key=importances.get, reverse=True)

    def _require_importances(self) -> Dict[str, float]:
        """Return the fitted importances or raise if ``fit`` was never called."""
        # getattr keeps this guard working whether or not the base class
        # pre-initialises ``feature_importances_``.
        importances = getattr(self, "feature_importances_", None)
        if importances is None:
            raise ValueError("Importance scores not available. Call fit() first.")
        return importances

    def _build_study(
        self, configs: List[Dict[str, Any]], scores: List[float]
    ) -> optuna.Study:
        """
        Build an Optuna study whose trials are exactly the given observations.

        Args:
            configs: Evaluated configurations (parameter name -> value).
            scores: Objective value of each configuration, in the same order.

        Returns:
            A study holding one completed trial per ``(config, score)`` pair.

        Raises:
            ValueError: If the two sequences differ in length or a value is
                not a bool, int, float or str (numpy scalars are unwrapped).
        """
        if len(configs) != len(scores):
            raise ValueError(
                "configs and scores must have the same length, "
                f"got {len(configs)} and {len(scores)}."
            )

        # The sampler is never invoked: trials are injected via add_trial().
        study = optuna.create_study(
            direction="maximize", sampler=optuna.samplers.RandomSampler(seed=42)
        )

        # Step 1: Validate values and collect everything observed per parameter.
        observed: Dict[str, List[Any]] = {}
        cleaned_configs: List[Dict[str, Any]] = []
        for config in configs:
            cleaned: Dict[str, Any] = {}
            for name, raw in config.items():
                # Unwrap numpy scalars (np.int64, np.float32, ...) to builtins.
                value = raw.item() if isinstance(raw, np.generic) else raw
                if not isinstance(value, (bool, int, float, str)):
                    raise ValueError(
                        f"Unsupported value type for param '{name}': {type(raw)}"
                    )
                cleaned[name] = value
                observed.setdefault(name, []).append(value)
            cleaned_configs.append(cleaned)

        # Step 2: Infer one distribution per parameter from ALL of its values,
        # so every trial shares the same search-space definition.
        distributions = {
            name: self._infer_distribution(values)
            for name, values in observed.items()
        }

        # Step 3: Register the observed configurations as completed trials.
        # study.ask() would let the sampler draw fresh random parameter values
        # and pair them with the scores; create_trial() keeps the real ones.
        for config, score in zip(cleaned_configs, scores):
            trial = optuna.trial.create_trial(
                params=config,
                distributions={name: distributions[name] for name in config},
                value=float(score),
            )
            study.add_trial(trial)

        return study

    @staticmethod
    def _infer_distribution(
        values: List[Any],
    ) -> optuna.distributions.BaseDistribution:
        """Pick the Optuna distribution that covers every observed value."""
        # bool is a subclass of int, so it must be checked before the numeric
        # branches; any bool/str makes the parameter categorical.
        if any(isinstance(v, (bool, str)) for v in values):
            # dict.fromkeys de-duplicates while keeping first-seen order, which
            # makes the choice order deterministic (unlike set()).
            return optuna.distributions.CategoricalDistribution(
                choices=list(dict.fromkeys(values))
            )
        if all(isinstance(v, int) for v in values):
            return optuna.distributions.IntDistribution(
                low=min(values), high=max(values)
            )
        # Pure floats, or an int/float mix promoted to a continuous range.
        return optuna.distributions.FloatDistribution(
            low=float(min(values)), high=float(max(values))
        )
